#include <os/debug.h>
#include <os/driver.h>
#include <os/initcall.h>
#include <os/diskman.h>
#include <os/virmem.h>
#include <sys/ioctl.h>
#include <lib/string.h>
#include <lib/bitop.h>
#include <os/task.h>
#include <driver/ramdisk.h>

/*
 * RamDiskEnter - driver entry: create the ramdisk device and its backing store.
 *
 * Creates one DEVICE_TYPE_VIRTUAL_DISK device object, allocates a buffer of
 * RAMDISK_SECTOR * SECTOR_SIZE bytes with VirMemAlloc, and registers the
 * device through DiskAdd.
 *
 * Returns IO_SUCCESS on success, IO_FAILED if the device object or the
 * backing buffer could not be created.
 */
iostatus_t RamDiskEnter(driver_object_t *driver)
{
    iostatus_t status;
    device_object_t *devobj;
    device_extension_t *extension;

    status = IoCreateDevice(driver, sizeof(device_extension_t), DEVICE_NAME, DEVICE_TYPE_VIRTUAL_DISK, &devobj);
    if (status != IO_SUCCESS)
    {
        // devobj is not known to be valid on failure, so it must not be
        // handed to IoDeleteDevice here.
        KPrint(PRINT_ERR "%s: create device failed!\n", __func__);
        return IO_FAILED;
    }
    // neither io mode
    devobj->flags = 0;
    extension = (device_extension_t *)devobj->device_extension;
    extension->sectors = RAMDISK_SECTOR;
    extension->len = extension->sectors * SECTOR_SIZE;
    extension->rwoff = 0;
    extension->buffer = VirMemAlloc(extension->len);
    if (!extension->buffer)
    {
        // undo the device creation so a half-initialised device is not left behind
        KPrint(PRINT_ERR "%s: alloc disk buffer failed!\n", __func__);
        IoDeleteDevice(devobj);
        return IO_FAILED;
    }

    KPrint("[ramdisk] size %d MB buff %x\n", (uint32_t)extension->len / MB, (uint32_t)extension->buffer);
    // register disk device
    DiskAdd(devobj, DISK_TYPE_DISK);
    return status;
}

/*
 * RamDiskExit - driver exit: tear down every device owned by this driver
 * and release the driver name string.
 *
 * Always returns IO_SUCCESS.
 *
 * NOTE(review): nothing in this function releases the buffer obtained from
 * VirMemAlloc in RamDiskEnter; unless IoDeleteDevice does so, it is leaked
 * on unload - confirm.
 */
iostatus_t RamDiskExit(driver_object_t *driver)
{
    device_object_t *cur, *tmp;

    // "safe" traversal: the current node is deleted while iterating
    list_traversal_all_owner_to_next_safe(cur, tmp, &driver->device_list, list)
    {
        IoDeleteDevice(cur);
    }

    // drop the name set up in RamDiskDriverFunc
    string_del(&driver->name);

    return IO_SUCCESS;
}

/*
 * RamDiskOpen - open dispatch routine.
 *
 * No per-open work is done: the request is marked successful with zero
 * info and completed immediately. Returns IO_SUCCESS.
 */
iostatus_t RamDiskOpen(device_object_t *device, io_request_t *ioreq)
{
    ioreq->io_status.info = 0;
    ioreq->io_status.status = IO_SUCCESS;
    IoCompleteRequest(ioreq);

    return IO_SUCCESS;
}

/*
 * RamDiskClose - close dispatch routine.
 *
 * No per-close work is done: the request is marked successful with zero
 * info and completed immediately. Returns IO_SUCCESS.
 */
iostatus_t RamDiskClose(device_object_t *device, io_request_t *ioreq)
{
    ioreq->io_status.info = 0;
    ioreq->io_status.status = IO_SUCCESS;
    IoCompleteRequest(ioreq);

    return IO_SUCCESS;
}

/*
 * RamDiskRead - read dispatch routine: copy bytes from the ramdisk buffer
 * into the request's user buffer.
 *
 * parame.read.offset is used as a sector index (it is multiplied by
 * SECTOR_SIZE to locate the data) and parame.read.len as a byte count.
 * An offset equal to DISKOFF_MAX selects the offset stored in the device
 * extension (rwoff) instead.
 *
 * On success io_status.info is set to the number of bytes copied and
 * IO_SUCCESS is returned. If the requested range does not lie completely
 * inside the disk, nothing is copied, info is 0 and IO_FAILED is returned.
 * The request is completed in both cases.
 */
iostatus_t RamDiskRead(device_object_t *device, io_request_t *ioreq)
{
    iostatus_t status = IO_SUCCESS;
    int info = 0;
    uint32_t off = ioreq->parame.read.offset;
    uint32_t len = ioreq->parame.read.len;
    device_extension_t *extension = device->device_extension;

    if (off == DISKOFF_MAX)
        off = extension->rwoff;

    // Range check written so that it cannot wrap: first make sure the start
    // sector exists, then compare the byte count with the bytes remaining
    // from that sector to the end of the disk. A request that ends exactly
    // at the end of the disk is valid.
    if (off >= extension->sectors ||
        len > (extension->sectors - off) * SECTOR_SIZE) // above limit
    {
        KPrint("[ramdisk] disk space above\n");
        status = IO_FAILED;
    }
    else
    {
        memcpy(ioreq->user_buff, extension->buffer + off * SECTOR_SIZE, len);
        info = len;
    }
    ioreq->io_status.status = status;
    ioreq->io_status.info = info;
    IoCompleteRequest(ioreq);
    return status;
}

/*
 * RamDiskWrite - write dispatch routine: copy bytes from the request's user
 * buffer into the ramdisk buffer.
 *
 * parame.write.offset is used as a sector index (it is multiplied by
 * SECTOR_SIZE to locate the destination) and parame.write.len as a byte
 * count. An offset equal to DISKOFF_MAX selects the offset stored in the
 * device extension (rwoff) instead.
 *
 * On success io_status.info is set to the number of bytes copied and
 * IO_SUCCESS is returned. If the requested range does not lie completely
 * inside the disk, nothing is written, info is 0 and IO_FAILED is returned.
 * The request is completed in both cases.
 */
iostatus_t RamDiskWrite(device_object_t *device, io_request_t *ioreq)
{
    iostatus_t status = IO_SUCCESS;
    uint32_t off = ioreq->parame.write.offset;
    uint32_t len = ioreq->parame.write.len;
    device_extension_t *extension = device->device_extension;
    int info = 0;

    if (off == DISKOFF_MAX)
        off = extension->rwoff;

    // Range check written so that it cannot wrap: first make sure the start
    // sector exists, then compare the byte count with the bytes remaining
    // from that sector to the end of the disk. A request that ends exactly
    // at the end of the disk is valid.
    if (off >= extension->sectors ||
        len > (extension->sectors - off) * SECTOR_SIZE) // above limit
    {
        KPrint("[ramdisk] disk space above\n");
        status = IO_FAILED;
    }
    else
    {
        memcpy(extension->buffer + off * SECTOR_SIZE, ioreq->user_buff, len);
        info = len;
    }
    ioreq->io_status.status = status;
    ioreq->io_status.info = info;
    IoCompleteRequest(ioreq);
    return status;
}

/*
 * RamDiskDevCtl - device-control dispatch routine.
 *
 * parame.devctl.code selects the command; parame.devctl.arg is treated as
 * the address of a uint32_t for the commands that transfer a value:
 *
 *   DISKIO_SETOFF  - store *arg in rwoff, clamped to the last sector index
 *   DISKIO_GETOFF  - write rwoff to *arg
 *   DISKIO_GETSIZE - write the sector count to *arg
 *   DISKIO_CLEAN   - zero the whole ramdisk buffer (arg unused)
 *
 * Returns IO_SUCCESS, or IO_FAILED for an unknown command or when a command
 * that needs arg is given a null one. io_status.info is always 0 and the
 * request is always completed.
 */
iostatus_t RamDiskDevCtl(device_object_t *device, io_request_t *ioreq)
{
    device_extension_t *extension = device->device_extension;
    iostatus_t status = IO_SUCCESS;
    uint32_t cmd = ioreq->parame.devctl.code;
    uint32_t arg = ioreq->parame.devctl.arg;
    uint32_t *value = (uint32_t *)arg;

    switch (cmd)
    {
    case DISKIO_SETOFF:
    {
        uint32_t off;

        if (!value)
        {
            status = IO_FAILED;
            break;
        }
        // keep the offset unsigned so the clamp below is a plain
        // unsigned comparison against the last valid sector index
        off = *value;
        if (off > extension->sectors - 1)
            off = extension->sectors - 1;
        extension->rwoff = off;
    }
    break;
    case DISKIO_GETOFF:
        if (!value)
        {
            status = IO_FAILED;
            break;
        }
        *value = extension->rwoff;
        break;
    case DISKIO_GETSIZE:
        if (!value)
        {
            status = IO_FAILED;
            break;
        }
        *value = extension->sectors;
        break;
    case DISKIO_CLEAN:
        memset(extension->buffer, 0, extension->len);
        break;
    default:
        status = IO_FAILED;
        break;
    }
    ioreq->io_status.status = status;
    ioreq->io_status.info = 0;
    IoCompleteRequest(ioreq);
    return status;
}

/*
 * RamDiskDriverFunc - fill in a driver object for the ramdisk driver.
 *
 * Installs the enter/exit callbacks, the open/close/read/write/devctl
 * dispatch routines, and sets the driver name to DRIVER_NAME.
 * Always returns IO_SUCCESS.
 */
iostatus_t RamDiskDriverFunc(driver_object_t *driver)
{
    // lifecycle callbacks
    driver->driver_enter = RamDiskEnter;
    driver->driver_exit = RamDiskExit;

    // per-request dispatch table
    driver->dispatch_fun[IOREQ_OPEN] = RamDiskOpen;
    driver->dispatch_fun[IOREQ_CLOSE] = RamDiskClose;
    driver->dispatch_fun[IOREQ_READ] = RamDiskRead;
    driver->dispatch_fun[IOREQ_WRITE] = RamDiskWrite;
    driver->dispatch_fun[IOREQ_DEVCTL] = RamDiskDevCtl;

    // driver name
    string_new(&driver->name, DRIVER_NAME, DRIVER_NAME_LEN);

    return IO_SUCCESS;
}

/*
 * RamDiskDriverEntry - init-time hook: create the ramdisk driver object
 * from RamDiskDriverFunc and log an error if creation fails.
 */
static __init void RamDiskDriverEntry()
{
    KPrint("[driver] init ramdisk driver\n");

    // a non-negative result means the driver object was created
    if (DriverObjectCreate(RamDiskDriverFunc) >= 0)
        return;

    KPrint(PRINT_ERR "[driver] %s: driver %s create failed!\n", __func__, DRIVER_NAME);
}

// Hand RamDiskDriverEntry to the driver_initcall macro (defined outside this
// file; presumably arranges for it to run during driver initialisation - confirm).
driver_initcall(RamDiskDriverEntry);
